{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第 8 章：卷积神经网络 — TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 载入数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-3f5d8b810895>:1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST splits from the local dataset directory with one-hot encoded labels.\n",
    "mnist = input_data.read_data_sets(\n",
    "    './datasets/ch08/tensorflow/MNIST',\n",
    "    one_hot=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(55000, 784)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Training split: shape of the image matrix, then of the one-hot label matrix.\n",
    "for split_array in (mnist.train.images, mnist.train.labels):\n",
    "    print(split_array.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 784)\n",
      "(5000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Validation split: shape of the image matrix, then of the one-hot label matrix.\n",
    "for split_array in (mnist.validation.images, mnist.validation.labels):\n",
    "    print(split_array.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 784)\n",
      "(10000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Test split: shape of the image matrix, then of the one-hot label matrix.\n",
    "for split_array in (mnist.test.images, mnist.test.labels):\n",
    "    print(split_array.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义创建模型需要的函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolution layer\n",
    "def conv2d(x, filter):\n",
    "    \"\"\"Convolve `x` with `filter` using a stride of 1 and 'SAME' padding.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor handed to tf.nn.conv2d.\n",
    "        filter: convolution kernel tensor handed to tf.nn.conv2d.\n",
    "\n",
    "    Returns:\n",
    "        The tensor produced by tf.nn.conv2d.\n",
    "    \"\"\"\n",
    "    unit_strides = [1, 1, 1, 1]\n",
    "    return tf.nn.conv2d(x, filter, strides=unit_strides, padding='SAME')\n",
    "\n",
    "# Pooling layer\n",
    "def max_pool(x):\n",
    "    \"\"\"Apply max pooling to `x` with a 2x2 window, a stride of 2 and 'SAME' padding.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor handed to tf.nn.max_pool.\n",
    "\n",
    "    Returns:\n",
    "        The tensor produced by tf.nn.max_pool.\n",
    "    \"\"\"\n",
    "    pool_window = [1, 2, 2, 1]\n",
    "    pool_strides = [1, 2, 2, 1]\n",
    "    return tf.nn.max_pool(x, ksize=pool_window, strides=pool_strides, padding='SAME')\n",
    "\n",
    "# Initial values for convolution kernels and weight matrices\n",
    "def weight_variable(shape):\n",
    "    \"\"\"Create a tf.Variable of the given shape drawn from a truncated normal (stddev 0.1).\n",
    "\n",
    "    Args:\n",
    "        shape: shape of the variable to create.\n",
    "\n",
    "    Returns:\n",
    "        A tf.Variable initialised with tf.truncated_normal(shape, stddev=0.1).\n",
    "    \"\"\"\n",
    "    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))\n",
    "\n",
    "# Initial values for biases\n",
    "def bias_variable(shape):\n",
    "    \"\"\"Create a tf.Variable of the given shape with every entry set to 0.1.\n",
    "\n",
    "    Args:\n",
    "        shape: shape of the variable to create.\n",
    "\n",
    "    Returns:\n",
    "        A tf.Variable initialised with the constant 0.1.\n",
    "    \"\"\"\n",
    "    return tf.Variable(tf.constant(0.1, shape=shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义占位符 placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flattened 28x28 input images; None leaves the number of images unspecified.\n",
    "# x must be the placeholder itself. Binding x to `placeholder / 255` made every\n",
    "# feed_dict={x: ...} override the division's output, so the scaling never ran and the\n",
    "# real placeholder was unreachable. No rescaling is needed here anyway: the MNIST\n",
    "# loader's default float32 mode already delivers pixel values in [0, 1].\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "# Restore the 2-D image layout (batch, height, width, channels) for the conv layers.\n",
    "x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "\n",
    "# y receives the one-hot ground-truth labels (it is the target, not the prediction).\n",
    "y = tf.placeholder(tf.float32, [None, 10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义 CNN 模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CONV1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First conv stage. Kernel: 5x5, input channels: 1, output channels: 32.\n",
    "W_conv1 = weight_variable([5, 5, 1, 32])\n",
    "b_conv1 = bias_variable([32])\n",
    "conv1_pre_activation = conv2d(x_image, W_conv1) + b_conv1\n",
    "h_conv1 = tf.nn.relu(conv1_pre_activation)  # output size: 28x28x32\n",
    "h_pool1 = max_pool(h_conv1)  # output size: 14x14x32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CONV2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second conv stage. Kernel: 5x5, input channels: 32, output channels: 64.\n",
    "W_conv2 = weight_variable([5, 5, 32, 64])\n",
    "b_conv2 = bias_variable([64])\n",
    "conv2_pre_activation = conv2d(h_pool1, W_conv2) + b_conv2\n",
    "h_conv2 = tf.nn.relu(conv2_pre_activation)  # output size: 14x14x64\n",
    "h_pool2 = max_pool(h_conv2)  # output size: 7x7x64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FC1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First fully connected layer: 7*7*64 flattened features -> 1024 hidden units.\n",
    "W_fc1 = weight_variable([7*7*64, 1024])\n",
    "b_fc1 = bias_variable([1024])\n",
    "# Flatten each pooled feature map into a single vector per example.\n",
    "h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])\n",
    "fc1_pre_activation = tf.matmul(h_pool2_flat, W_fc1) + b_fc1\n",
    "h_fc1 = tf.nn.relu(fc1_pre_activation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FC2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output layer: 1024 hidden units -> 10 class scores, turned into probabilities by softmax.\n",
    "W_fc2 = weight_variable([1024, 10])\n",
    "b_fc2 = bias_variable([10])\n",
    "logits = tf.matmul(h_fc1, W_fc2) + b_fc2\n",
    "prediction = tf.nn.softmax(logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数与优化算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy loss: batch mean of -sum(y * log(p)) over the class axis.\n",
    "# Probabilities are clipped away from 0 first. Without the clip, a softmax output that\n",
    "# underflows to 0 makes log() return -inf, and 0 * -inf is NaN, which poisons the loss\n",
    "# and every gradient that follows.\n",
    "clipped_prediction = tf.clip_by_value(prediction, 1e-10, 1.0)\n",
    "# `axis` replaces the deprecated `reduction_indices` argument.\n",
    "per_example_loss = -tf.reduce_sum(y * tf.log(clipped_prediction), axis=1)\n",
    "cross_entropy = tf.reduce_mean(per_example_loss)\n",
    "# Adam with learning rate 1e-4 minimises the mean cross-entropy.\n",
    "train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check whether the predicted class matches the labelled class for each example.\n",
    "predicted_class = tf.argmax(prediction, 1)\n",
    "true_class = tf.argmax(y, 1)\n",
    "correct_prediction = tf.equal(predicted_class, true_class)\n",
    "# Accuracy is the mean of the matches after casting booleans to floats.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练并验证准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train accuracy 0.810\n",
      "train accuracy 0.890\n",
      "train accuracy 0.880\n",
      "train accuracy 0.960\n",
      "train accuracy 0.950\n",
      "train accuracy 0.940\n",
      "train accuracy 0.950\n",
      "train accuracy 0.940\n",
      "train accuracy 0.940\n",
      "train accuracy 0.960\n",
      "train accuracy 0.980\n",
      "train accuracy 0.950\n",
      "train accuracy 0.940\n",
      "train accuracy 0.980\n",
      "train accuracy 0.980\n",
      "train accuracy 0.980\n",
      "train accuracy 1.000\n",
      "train accuracy 0.990\n",
      "train accuracy 0.990\n",
      "train accuracy 0.990\n",
      "test accuracy 0.977\n"
     ]
    }
   ],
   "source": [
    "TRAIN_STEPS = 1000       # number of optimizer updates\n",
    "TRAIN_BATCH_SIZE = 100   # examples per update\n",
    "LOG_EVERY = 50           # report accuracy every LOG_EVERY updates\n",
    "EVAL_BATCH_SIZE = 1000   # test examples per forward pass\n",
    "\n",
    "# The session is left open on purpose so later cells can reuse the trained weights.\n",
    "sess = tf.Session()\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)\n",
    "\n",
    "for step in range(1, TRAIN_STEPS + 1):\n",
    "    batch_x, batch_y = mnist.train.next_batch(TRAIN_BATCH_SIZE)\n",
    "    sess.run(train_step, feed_dict={x: batch_x, y: batch_y})\n",
    "    if step % LOG_EVERY == 0:\n",
    "        # Measured on the mini-batch that was just used for the update.\n",
    "        batch_accuracy = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})\n",
    "        print(\"train accuracy %.3f\" % batch_accuracy)\n",
    "\n",
    "# Evaluate the test set in chunks rather than one giant forward pass, which can\n",
    "# exhaust memory. Each chunk's accuracy is weighted by its size, so the result\n",
    "# equals the accuracy over the whole set.\n",
    "num_test = len(mnist.test.images)\n",
    "correct_total = 0.0\n",
    "for start in range(0, num_test, EVAL_BATCH_SIZE):\n",
    "    test_x = mnist.test.images[start:start + EVAL_BATCH_SIZE]\n",
    "    test_y = mnist.test.labels[start:start + EVAL_BATCH_SIZE]\n",
    "    chunk_accuracy = sess.run(accuracy, feed_dict={x: test_x, y: test_y})\n",
    "    correct_total += chunk_accuracy * len(test_x)\n",
    "print(\"test accuracy %.3f\" % (correct_total / num_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
